import argparse
import os

import yaml


class FunctionTag:
    """
    Minimal holder for a single wrapped object.

    The object passed to the constructor is kept, unchanged, on the
    ``value`` attribute.

    NOTE(review): this class is not referenced anywhere in this script's
    visible code; presumably it was meant for a custom YAML tag/representer
    -- confirm before relying on it or removing it.
    """

    def __init__(self, value):
        # Stored as-is: no copying, validation, or conversion.
        self.value = value


def prompt_func(mode, lang):
    prompt_map = {
        "prompt_1": "Your task is to answer a question given a context."
        "Make sure you respond with the shortest span containing the answer in the context.\n"
        "Question: {{question_lang}}\n"
        "Context: {{context}}\n"
        "Answer:",
        "prompt_2": f"Your task is to answer a question given a context. The question is in {lang}, while the context is in English or French."
        "Make sure you respond with the shortest span in the context that contains the answer.\n"
        "Question: {{question_lang}}\n"
        "Context: {{context}}\n"
        "Answer:",
        "prompt_3": "Given the context, provide the answer to the following question."
        "Ensure your response is concise and directly from the context.\n"
        "Question: {{question_lang}}\n"
        "Context: {{context}}\n"
        "Answer:",
        "prompt_4": "You are an AI assistant and your task is to answer the question based on the provided context."
        "Your answer should be the shortest span that contains the answer within the context.\n"
        "Question: {{question_lang}}\n"
        "Context: {{context}}\n"
        "Answer:",
        "prompt_5": "Using the context, find the answer to the question."
        "Respond with the briefest span that includes the answer from the context.\n"
        "Question: {{question_lang}}\n"
        "Context: {{context}}\n"
        "Answer:",
    }
    return prompt_map[mode]


def gen_lang_yamls(output_dir: str, overwrite: bool, mode: str) -> None:
    """
    Generate one AfriQA task yaml file per language for a prompt variant.

    Files are written to ``<output_dir>/<mode>/afriqa_<lang>.yaml``. Each one
    includes the shared ``afriqa`` template and sets its own task name,
    dataset name and ``doc_to_text`` prompt (from ``prompt_func``).

    :param output_dir: The directory to output the files to.
    :param overwrite: Whether to overwrite files if they already exist.
    :param mode: Prompt identifier passed to ``prompt_func`` (e.g.
        ``"prompt_1"``); also used as the sub-directory name and as the
        task-name suffix.
    :raises FileExistsError: If ``overwrite`` is False and some target files
        already exist. Every other file is still written before this is
        raised, and the message lists the skipped file names.
    """
    languages = {
        "bem": "Bemba",
        "fon": "Fon",
        "hau": "Hausa",
        "ibo": "Igbo",
        "kin": "Kinyarwanda",
        "swa": "Swahili",
        "twi": "Twi",
        "wol": "Wolof",
        "yor": "Yoruba",
        "zul": "Zulu",
    }

    # Created once, up front. A failure here (e.g. the path exists as a
    # regular file) is a real directory problem, so it propagates as-is
    # rather than being misreported as a per-file "use --overwrite" collision.
    mode_dir = os.path.join(output_dir, mode)
    os.makedirs(mode_dir, exist_ok=True)

    # "x" = exclusive create: open() raises FileExistsError instead of
    # clobbering an existing file.
    open_mode = "w" if overwrite else "x"
    skipped = []

    for lang, lang_name in languages.items():
        file_name = f"afriqa_{lang}.yaml"
        yaml_details = {
            "include": "afriqa",
            "task": f"afriqa_{lang}_{mode}",
            "dataset_name": lang,
            "doc_to_text": prompt_func(mode, lang_name),
        }
        try:
            with open(
                os.path.join(mode_dir, file_name),
                open_mode,
                encoding="utf8",
            ) as f:
                f.write("# Generated by utils.py\n")
                yaml.dump(
                    yaml_details,
                    f,
                    allow_unicode=True,
                )
        except FileExistsError:
            # Keep going so the remaining languages are still generated;
            # collisions are reported together below.
            skipped.append(file_name)

    if skipped:
        raise FileExistsError(
            "Files were not created because they already exist (use --overwrite flag):"
            f" {', '.join(skipped)}"
        )


def build_arg_parser() -> argparse.ArgumentParser:
    """
    Build the command-line parser for this generator script.

    Overwriting existing files stays the default. ``--no-overwrite`` is the
    opt-out. Previously ``--overwrite`` was ``store_true`` with
    ``default=True``, so overwrite was always on and the flag did nothing.
    ``--overwrite`` is still accepted for backward compatibility.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="Overwrite files if they already exist",
    )
    parser.add_argument(
        "--no-overwrite",
        dest="overwrite",
        action="store_false",
        help="Fail instead of overwriting files that already exist",
    )
    # One explicit default shared by both flags that write `overwrite`.
    parser.set_defaults(overwrite=True)
    parser.add_argument(
        "--output-dir",
        default="./",
        help="Directory to write yaml files to",
    )
    parser.add_argument(
        "--mode",
        default="prompt_1",
        choices=["prompt_1", "prompt_2", "prompt_3", "prompt_4", "prompt_5"],
        help="Prompt number",
    )
    return parser


def main() -> None:
    """Parse CLI args and generate language-specific yaml files."""
    args = build_arg_parser().parse_args()

    gen_lang_yamls(output_dir=args.output_dir, overwrite=args.overwrite, mode=args.mode)


if __name__ == "__main__":
    main()
